package top.lingkang.finaloauth2.server.store.impl;

import top.lingkang.finaloauth2.entity.*;
import top.lingkang.finaloauth2.server.store.TokenStore;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * In-memory {@link TokenStore} backed by {@link ConcurrentHashMap}s.
 *
 * <p>Expired access/refresh tokens are removed lazily when they are looked up, and in bulk every
 * {@code CLEAR_INTERVAL} calls to {@link #storeAccessToken} by draining a {@link DelayQueue} of
 * expiry timestamps. Authorization codes are single-use (removed on read) and expired ones are
 * swept every {@code AUTH_CODE_CLEAR_INTERVAL} calls to {@link #setCode}.
 *
 * <p>NOTE(review): {@link DelayQueue} does not implement {@link Serializable}, so Java
 * serialization of an instance of this class will fail despite the {@code implements
 * Serializable} clause — confirm whether serialization is actually required.
 *
 * @author lingkang
 * Created by 2022/3/24
 */
public class InMemoryTokenStore implements TokenStore, Serializable {
    /** Access token value -> access token. */
    private final ConcurrentHashMap<String, OAuth2AccessToken> accessTokenMap = new ConcurrentHashMap<>();
    /** Access token value -> user details the token was stored with. */
    private final ConcurrentHashMap<String, UserDetails> userDetailsMap = new ConcurrentHashMap<>();
    /** Refresh token value -> refresh token. */
    private final ConcurrentHashMap<String, OAuth2RefreshToken> refreshTokenMap = new ConcurrentHashMap<>();
    /** Access token value -> value of the refresh token stored alongside it. */
    private final ConcurrentHashMap<String, String> tokenAndRefresh = new ConcurrentHashMap<>();
    /** "username#clientId" -> most recently stored access token value for that user/client. */
    private final ConcurrentHashMap<String, String> userAndToken = new ConcurrentHashMap<>();
    /** "username##clientId" -> most recently stored refresh token value for that user/client. */
    private final ConcurrentHashMap<String, String> userAndRefresh = new ConcurrentHashMap<>();
    /** Number of access tokens stored since the last expiry sweep. */
    private final AtomicInteger storeCounter = new AtomicInteger(0);
    /** Run an expiry sweep once more than this many access tokens have been stored. */
    private static final int CLEAR_INTERVAL = 500;
    /** Expiry times of stored access and refresh tokens, ordered by expiry. */
    private final DelayQueue<TokenExpiry> expiryQueue = new DelayQueue<>();

    // Authorization codes
    /** Authorization code value -> authorization code. */
    private final ConcurrentHashMap<String, Oauth2Code> authCode = new ConcurrentHashMap<>();
    /** Sweep expired authorization codes once more than this many codes have been stored. */
    private static final int AUTH_CODE_CLEAR_INTERVAL = 100;
    /** Number of authorization codes stored since the last sweep. */
    private final AtomicInteger codeCounter = new AtomicInteger(0);

    /** Builds the {@link #userAndToken} key for a user/client pair. */
    private static String accessKey(String username, String clientId) {
        return username + "#" + clientId;
    }

    /** Builds the {@link #userAndRefresh} key for a user/client pair. */
    private static String refreshKey(String username, String clientId) {
        return username + "##" + clientId;
    }

    /**
     * Looks up an access token by value.
     *
     * @param token the access token value
     * @return the stored token, or {@code null} if unknown or expired; an expired token is removed
     */
    @Override
    public OAuth2AccessToken getAccessToken(String token) {
        OAuth2AccessToken accessToken = accessTokenMap.get(token);
        if (accessToken != null && accessToken.getExpiry() < System.currentTimeMillis()) {
            removeAccessToken(token);
            return null;
        }
        return accessToken;
    }

    /**
     * Returns the currently valid access and refresh tokens (with expiries) for a user/client.
     * Fields of the returned info are left unset when the corresponding token is missing or expired.
     *
     * @param username the user name
     * @param clientId the client id
     * @return a non-null info object
     */
    @Override
    public UserTokenInfo getUserTokenInfo(String username, String clientId) {
        UserTokenInfo info = new UserTokenInfo();
        String token = userAndToken.get(accessKey(username, clientId));
        if (token != null) {
            OAuth2AccessToken accessToken = getAccessToken(token);
            if (accessToken != null) {
                info.setToken(token);
                info.setTokenExpires(accessToken.getExpiry());
            }
        }
        String refreshToken = userAndRefresh.get(refreshKey(username, clientId));
        if (refreshToken != null) {
            OAuth2RefreshToken ref = getRefreshToken(refreshToken);
            if (ref != null) {
                info.setRefreshToken(refreshToken);
                info.setRefreshExpires(ref.getExpiry());
            }
        }
        return info;
    }

    /**
     * Removes an access token together with its user details and index entries.
     * The associated refresh token itself is not removed, only the token→refresh link.
     *
     * @param token the access token value
     * @return the removed token, or {@code null} if it was not stored
     */
    @Override
    public OAuth2AccessToken removeAccessToken(String token) {
        tokenAndRefresh.remove(token); // drop the token -> refresh token link
        // remove (not just read) so the user details do not outlive the token
        UserDetails userDetails = userDetailsMap.remove(token);
        if (userDetails != null) {
            // Conditional remove: the user may already hold a newer token under the same key,
            // and that mapping must survive the removal of this older token.
            userAndToken.remove(accessKey(userDetails.getUsername(), userDetails.getClientId()), token);
        }
        return accessTokenMap.remove(token);
    }

    /**
     * Stores an access token for a user, occasionally sweeping expired tokens first.
     *
     * @param accessToken the token to store
     * @param userDetails the user the token belongs to
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken accessToken, UserDetails userDetails) {
        if (storeCounter.incrementAndGet() > CLEAR_INTERVAL) {
            clearExpiry();
            storeCounter.set(0);
        }
        String token = accessToken.getToken();
        expiryQueue.add(new TokenExpiry(token, accessToken.getExpiry()));
        accessTokenMap.put(token, accessToken);
        userDetailsMap.put(token, userDetails);
        userAndToken.put(accessKey(userDetails.getUsername(), userDetails.getClientId()), token);
    }

    /**
     * Returns the user details for a still-valid access token.
     *
     * @param token the access token value
     * @return the user details, or {@code null} if the token is unknown or expired
     */
    @Override
    public UserDetails getUserDetails(String token) {
        OAuth2AccessToken accessToken = getAccessToken(token);
        if (accessToken != null) {
            return userDetailsMap.get(token);
        }
        return null;
    }

    /**
     * Looks up a refresh token by value.
     *
     * @param refreshToken the refresh token value
     * @return the stored refresh token, or {@code null} if unknown or expired; an expired refresh
     *     token is removed (together with its access token)
     */
    @Override
    public OAuth2RefreshToken getRefreshToken(String refreshToken) {
        OAuth2RefreshToken stored = refreshTokenMap.get(refreshToken);
        if (stored != null && stored.getExpiry() < System.currentTimeMillis()) {
            // expired refresh token
            removeRefreshToken(refreshToken);
            return null;
        }
        return stored;
    }

    /**
     * Removes a refresh token, the access token it references, and its user index entry.
     *
     * @param refreshToken the refresh token value
     * @return the removed refresh token, or {@code null} if it was not stored
     */
    @Override
    public OAuth2RefreshToken removeRefreshToken(String refreshToken) {
        OAuth2RefreshToken removed = refreshTokenMap.remove(refreshToken);
        if (removed != null) {
            removeAccessToken(removed.getAccessToken().getToken());
            // Conditional remove: keep the mapping if it already points at a newer refresh token.
            userAndRefresh.remove(
                    refreshKey(removed.getUserDetails().getUsername(), removed.getUserDetails().getClientId()),
                    refreshToken);
        }
        return removed;
    }

    /**
     * Stores a refresh token and links it to its access token and user/client.
     *
     * @param refreshToken the refresh token to store
     */
    @Override
    public void storeRefreshToken(OAuth2RefreshToken refreshToken) {
        String value = refreshToken.getRefreshToken();
        expiryQueue.add(new TokenExpiry(value, refreshToken.getExpiry()));
        refreshTokenMap.put(value, refreshToken);
        tokenAndRefresh.put(refreshToken.getAccessToken().getToken(), value);
        userAndRefresh.put(
                refreshKey(refreshToken.getUserDetails().getUsername(), refreshToken.getUserDetails().getClientId()),
                value);
    }

    /**
     * Stores an authorization code, occasionally sweeping expired codes first.
     *
     * @param oauth2Code the code to store
     */
    @Override
    public void setCode(Oauth2Code oauth2Code) {
        if (codeCounter.incrementAndGet() > AUTH_CODE_CLEAR_INTERVAL) {
            clearAuthCode();
            codeCounter.set(0);
        }
        authCode.put(oauth2Code.getCode(), oauth2Code);
    }

    /**
     * Consumes an authorization code: it is removed whether or not it is still valid.
     *
     * @param code the code value
     * @return the code if it existed and had not expired, otherwise {@code null}
     */
    @Override
    public Oauth2Code getOauth2Code(String code) {
        Oauth2Code removed = authCode.remove(code);
        if (removed != null && removed.getExpiry() > System.currentTimeMillis()) {
            return removed;
        }
        return null;
    }

    /** Removes every token, index entry, pending expiry and authorization code, and resets counters. */
    public void clearAll() {
        accessTokenMap.clear();
        userAndToken.clear();
        userAndRefresh.clear();
        userDetailsMap.clear();
        refreshTokenMap.clear();
        tokenAndRefresh.clear();
        expiryQueue.clear();
        storeCounter.set(0);
        authCode.clear();
        codeCounter.set(0);
    }

    /** Removes all authorization codes whose expiry is in the past. */
    private void clearAuthCode() {
        if (authCode.isEmpty()) {
            return;
        }
        long now = System.currentTimeMillis();
        // ConcurrentHashMap iterators are weakly consistent, so removing while iterating is safe.
        for (Map.Entry<String, Oauth2Code> entry : authCode.entrySet()) {
            Oauth2Code value = entry.getValue();
            if (value.getExpiry() < now) {
                // conditional remove: leave a code that was concurrently replaced under the same key
                authCode.remove(entry.getKey(), value);
            }
        }
    }

    /** @return the number of access tokens currently held (including not-yet-swept expired ones) */
    public int getAllToken() {
        return accessTokenMap.size();
    }

    /** @return the number of refresh tokens currently held (including not-yet-swept expired ones) */
    public int getAllRefreshToken() {
        return refreshTokenMap.size();
    }

    /** Drains every entry whose delay has elapsed from the expiry queue and removes that token. */
    private void clearExpiry() {
        for (TokenExpiry expiry = expiryQueue.poll(); expiry != null; expiry = expiryQueue.poll()) {
            // The queue holds both kinds of token value, so try both removals; the wrong one is a no-op.
            removeAccessToken(expiry.getToken());  // may be an access token
            removeRefreshToken(expiry.getToken()); // may be a refresh token
        }
    }

    /** Delay-queue entry pairing a token value with its absolute expiry time (epoch millis). */
    private static class TokenExpiry implements Delayed {
        private final long expiry;
        private final String token;

        public TokenExpiry(String token, long expiry) {
            this.token = token;
            this.expiry = expiry;
        }

        /**
         * Remaining time until expiry (expiry time minus current time) in the requested unit;
         * zero or negative means the entry is available from the queue.
         */
        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(expiry - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
        }

        /**
         * Orders entries by remaining delay, earliest expiry first. Uses {@link Long#compare} rather
         * than casting a long difference to int, which can overflow and corrupt the queue order.
         */
        @Override
        public int compareTo(Delayed o) {
            if (o instanceof TokenExpiry) {
                return Long.compare(this.expiry, ((TokenExpiry) o).expiry);
            }
            return Long.compare(this.getDelay(TimeUnit.MILLISECONDS), o.getDelay(TimeUnit.MILLISECONDS));
        }

        /**
         * @return the token value this entry tracks
         */
        public String getToken() {
            return this.token;
        }
    }

}
